import pysam
import pybedtools
from pylab import *


# Sequencing platforms; each has its own SAM files and annotation table.
datasets = ("MiSeq", "HiSeq")
# Time points (in hours) and replicate numbers; together they form the
# library names of the form "tHH_rN".
timepoints = (0, 1, 4, 12, 24, 96)
replicates = (1, 2, 3)

# Genomic window to analyse: from the smallest start to the largest end over
# all intervals listed in the BED file.
filename = "dbi.bed"
promoters = pybedtools.BedTool(filename)
intervals = [(promoter.start, promoter.end) for promoter in promoters]
if not intervals:
    # Fail here with a clear message; otherwise the problem would only show
    # up later as a negative array dimension.
    raise ValueError("no intervals found in %s" % filename)
minimum = min(start for start, end in intervals)
maximum = max(end for start, end in intervals)

# Per-dataset accumulators, sized from the experimental design rather than
# from hard-coded numbers:
#   counts[dataset][i, j, k]: tags of time point i, replicate j that start at
#                             position minimum + k of the window
#   totals[dataset][i, j]:    overall total for time point i, replicate j
n_timepoints = len(timepoints)
n_replicates = len(replicates)
width = maximum - minimum
totals = {}
counts = {}
for dataset in datasets:
    counts[dataset] = zeros((n_timepoints, n_replicates, width), int)
    totals[dataset] = zeros((n_timepoints, n_replicates), int)

# For every library of every dataset, count how many alignments start at each
# position of the half-open window [minimum, maximum).
for i, timepoint in enumerate(timepoints):
    for j, replicate in enumerate(replicates):
        if timepoint == 1 and replicate == 3:
            # HiSeq negative control using water instead of RNA as starting material
            continue
        library = "t%02d_r%d" % (timepoint, replicate)
        for dataset in datasets:
            filename = "%s.%s.sam" % (dataset, library)
            print("Reading", filename)
            count = zeros(maximum-minimum, int)
            # The context manager closes the alignment file even when one of
            # the checks below fails part-way through.
            with pysam.AlignmentFile(filename) as alignments:
                for alignment in alignments:
                    # NH is the number of reported alignments for the read;
                    # the input is required to contain single-hit reads only.
                    nh = alignment.get_tag('NH')
                    assert nh == 1
                    # NOTE(review): strand is not consulted here, and
                    # reference_start is the leftmost aligned position.
                    # Presumably all reads are on the forward strand - confirm.
                    tss = alignment.reference_start
                    if alignment.query_alignment_sequence[0] == "G":
                        operation = alignment.cigartuples[0][0]
                        if operation == pysam.CMATCH:
                            md = alignment.get_tag("MD")
                            if md[0] == '0':  # mismatched G at the first position
                                tss += 1
                    if minimum <= tss < maximum:
                        count[tss - minimum] += 1
            counts[dataset][i, j, :] = count

# Fill totals[dataset] with the column sums of each dataset's annotation
# table (tab-separated, one column per library after three leading columns).
for dataset in datasets:
    filename = "annotations.%s.txt" % dataset
    print("Reading", filename)
    # "with" guarantees the file is closed even if a check below fails.
    with open(filename) as stream:
        # Header line: three descriptive columns, then the library names.
        line = next(stream)
        words = line.strip().split("\t")
        assert words[:3] == ['#rank', 'annotation', 'transcript']
        libraries = words[3:]
        n = len(libraries)
        total = zeros(n)
        for line in stream:
            words = line.strip().split("\t")
            assert len(words[3:]) == n
            total += array(words[3:], int)
    for i, timepoint in enumerate(timepoints):
        for j, replicate in enumerate(replicates):
            if timepoint == 1 and replicate == 3:
                # HiSeq negative control using water instead of RNA as starting material
                continue
            library = "t%02d_r%d" % (timepoint, replicate)
            # Raises ValueError if the table has no column for this library.
            k = libraries.index(library)
            totals[dataset][i, j] = total[k]

# Figure 1: for every library, the percentage of its total that starts inside
# the promoter window. One panel per dataset, one marker style per replicate.
colors = {"MiSeq": 'red', "HiSeq": 'blue'}
markers = ('o', 's', '*')
ticks = arange(len(timepoints))
for i, dataset in enumerate(datasets):
    subplot(len(datasets), 1, i+1)
    color = colors[dataset]
    for j, replicate in enumerate(replicates):
        # Use the array method so the row sum does not depend on pylab
        # shadowing the builtin sum().
        hits = counts[dataset][:, j, :].sum(axis=1)
        sequenced = totals[dataset][:, j]
        # A library that was skipped has a total of zero. Leave it as NaN so
        # it is simply not drawn, instead of dividing 0 by 0.
        percentage = full(len(timepoints), nan)
        available = sequenced > 0
        percentage[available] = 100 * hits[available] / sequenced[available]
        plot(percentage, marker=markers[j % len(markers)], linestyle='none', color=color, label="replicate %d" % replicate)
    ylabel("Percentage of sequenced reads", fontsize=8)
    title(dataset, fontsize=8)
    legend(fontsize=8)
    xticks(ticks, [])
    yticks(fontsize=8)

# Only the last (bottom) panel gets tick labels and an x-axis label.
xticks(ticks, ['%d hr' % timepoint for timepoint in timepoints], fontsize=8)
xlabel("Time point", fontsize=8)

# Write the current figure once as SVG and once as PNG.
for filename in ("figure_dbi_overrepresentation.svg",
                 "figure_dbi_overrepresentation.png"):
    print("Saving figure to %s" % filename)
    savefig(filename)


# Figure 2 scaffolding. The axes created here have no visible spines or
# ticks; they only carry the shared axis labels and one label per row.
sides = ('top', 'bottom', 'left', 'right')

fig = figure(figsize=(6, 12))
subplots_adjust(left=0.16, right=0.92, wspace=0.7, top=0.97, bottom=0.14)

# Full-figure axes holding the common x label and the left-hand y label.
ax = fig.add_subplot(111)
for side in sides:
    ax.spines[side].set_color('none')
xticks([])
yticks([])
ax.set_xlabel("Position on chromosome 2", fontsize=8, labelpad=65)
ax.set_ylabel("HiSeq tag count", color='blue', labelpad=42, fontsize=8)

# Twin of the axes above, holding the right-hand y label.
ax2 = twinx()
for side in sides:
    ax2.spines[side].set_color('none')
xticks([])
yticks([])
ax2.set_ylabel("MiSeq tag count", color='red', labelpad=24, fontsize=8)

# One full-width row per time point, used only for its y label.
for i, timepoint in enumerate(timepoints):
    ax = fig.add_subplot(611 + i)
    for side in sides:
        ax.spines[side].set_color('none')
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
    ax.set_ylabel("%d hour" % timepoint, labelpad=34, fontsize=8)

# One panel per library: HiSeq counts on the left y-axis (blue) and MiSeq
# counts on a twinned right y-axis (red), as step histograms over the window.
x = arange(minimum, maximum)
# Unit-width bins centred on the integer positions. The edges must run up to
# maximum-0.5 inclusive (one more edge than positions); with fewer edges the
# counts at position maximum-1 fall outside the last bin and are dropped.
bins = arange(minimum-0.5, maximum+0.5)
positions = arange(minimum, maximum, 10)
n_rows = len(timepoints)
n_columns = len(replicates)
for i, timepoint in enumerate(timepoints):
    for j, replicate in enumerate(replicates):
        if timepoint == 1 and replicate == 3:
            # HiSeq negative control using water instead of RNA as starting material
            continue
        ax = fig.add_subplot(n_rows, n_columns, n_columns*i+j+1)
        hist(x, weights=counts['HiSeq'][i, j, :], bins=bins, color='blue', histtype='step')
        yticks(fontsize=8, color='blue')
        if timepoint == timepoints[-1]:
            # Only the bottom row shows the genomic coordinates.
            labels = [str(position) for position in positions]
            xticks(positions, labels, rotation=90, fontsize=8)
        else:
            xticks(positions, [])
        xlim(minimum, maximum)
        ax2 = twinx()
        hist(x, weights=counts['MiSeq'][i, j, :], bins=bins, color='red', histtype='step')
        yticks(fontsize=8, color='red')
        xlim(minimum, maximum)
        if timepoint == timepoints[0]:
            # Only the top row carries the column titles.
            title("Replicate %d" % replicate, fontsize=8)

# Write the per-library promoter figure once as SVG and once as PNG.
for filename in ("figure_dbi_overrepresentation_promoter.svg",
                 "figure_dbi_overrepresentation_promoter.png"):
    print("Saving figure to %s" % filename)
    savefig(filename)
